{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8ef9ad9-c7e1-4fed-b0bd-581064558089",
   "metadata": {},
   "source": [
    "# GPT-3.5-Turbo Performance on MMLU - Astronomy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f1c09a4-4859-469b-a156-dbb037c83a65",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import openai\n",
    "import re\n",
    "import time\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eaea2a2a-2515-4508-9cdb-084d10853170",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "openai.api_key = \"sk-\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "95ddd688-5bf5-40c5-a852-32e62f5a1bbb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +\n",
    "                       [wait_fixed(5) for i in range(2)] +\n",
    "                       [wait_fixed(10)]))\n",
    "def completion_with_backoff(**kwargs):\n",
    "    return openai.ChatCompletion.create(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4063503-c0e0-4df7-9866-0815f4c9fdf0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mmlu_prompt = json.load(open('lib_prompt/mmlu-cot.json'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bbbbe828-fad8-48a9-9cab-4a7e675ecf9e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about astronomy.\n",
      "\n",
      "Q: Where do most short-period comets come from and how do we know?\n",
      "(A) The Kuiper belt; short period comets tend to be in the plane of the solar system just like the Kuiper belt. (B) The Kuiper belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the Kuiper belt. (C) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt. (D) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.\n",
      "A: Let's think step by step. Most short-period comets come from the Kuiper belt, and we know because short period coments tend to be in the plane of the solar system, just like the Kuiper belt is. The answer is (A).\n",
      "\n",
      "Q: You are pushing a truck along a road. Would it be easier to accelerate this truck on Mars? Why? (Assume there is no friction)\n",
      "(A) It would be harder since the truck is heavier on Mars. (B) It would be easier since the truck is lighter on Mars. (C) It would be harder since the truck is lighter on Mars. (D) It would be the same no matter where you are.\n",
      "A: Let's think step by step. If we assume that there is no friction, the force needed to accelerate the truck is by Newton’s second law only dependent on the mass of the truck. Hence (A), (B) and (C) are incorrect since it doesn’t matter that it’s on Mars, and (D) is the correct answer. The answer is (D).\n",
      "\n",
      "Q: Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?\n",
      "(A) 10000 times more (B) 100 times more (C) 1000 times more (D) 10 times more\n",
      "A: Let's think step by step. The amount of light is proportional to the aperture area $A = \\pi D^2/4$ for a lens with diameter $D$, so the relative amounts of light between the eye with diameter 5mm and the telescope with diameter 50mm is $(50 cm)^2/(5mm)^2 = 10000$. The answer is (A).\n",
      "\n",
      "Q: Why isn't there a planet where the asteroid belt is located?\n",
      "(A) A planet once formed here but it was broken apart by a catastrophic collision. (B) There was not enough material in this part of the solar nebula to form a planet. (C) There was too much rocky material to form a terrestrial planet but not enough gaseous material to form a jovian planet. (D) Resonance with Jupiter prevented material from collecting together to form a planet.\n",
      "A: Let's think step by step. The asteroid belt is a stellar disc consisting of a large number of asteroids between Mars and Jupiter's orbits. The asteroids in this belt are affected by the gravitational pull from both other asteroids and nearby planets. Due to the strong gravitational force of Jupiter there are resonances that give rise to low density regions of asteroids known as the Kirkwood gap. So (B) and (C) are not correct since it’s not a lack of material that prevents a planet from being formed, and (A) is incorrect because the Kirkwood gap would have prevented a planet from forming in the first place, and (D) is the correct option. The answer is (D).\n",
      "\n",
      "Q: Why is Mars red?\n",
      "(A) Because the surface is covered with heavily oxidized (\"rusted\") minerals. (B) Because the atmosphere scatters more light at bluer wavelengths transmitting mostly red light. (C) Because Mars is covered with ancient lava flows which are red in color. (D) Because flowing water on Mars's surface altered the surface minerals several billion years ago.\n",
      "A: Let's think step by step. Option (B) is not correct because if the red color was caused by the scattering off the atmosphere, then the earth with a much thicker atmosphere would also look red. Options (C) and (D) are not specific enough about why the color of the surface would be red, while (A) is correct because it explains that the surface is red due to the rusted materials on the surface and the red color comes from the rust. So the correct option is (A). The answer is (A).\n"
     ]
    }
   ],
   "source": [
    "task = 'astronomy'\n",
    "print(mmlu_prompt[task])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ce1c6b79-3530-4efe-bfd6-eead27d09ff3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset mmlu/astronomy to /Users/yaofu/.cache/huggingface/datasets/lukaemon___mmlu/astronomy/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2758625d83e04303826a0be5d0edac85",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/151 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/15 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f5cc87ee4c44676aff88c748c6991e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/4 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset mmlu downloaded and prepared to /Users/yaofu/.cache/huggingface/datasets/lukaemon___mmlu/astronomy/1.0.0/134145dc2582b9a08b42d1f4b828f84a0066e9cc2e7dd8c1d83bee475746ecc3. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3149be1837d54e1b95f82f89b09b1065",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "task_data = load_dataset(\"lukaemon/mmlu\", task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0957327f-9454-4010-9501-b7086fbba125",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "151"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(task_data['test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5b18d727-13a7-45cd-a842-3462636d9c25",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': 'If you know both the actual brightness of an object and its apparent brightness from your location then with no other information you can estimate:',\n",
       " 'A': 'Its speed relative to you',\n",
       " 'B': 'Its composition',\n",
       " 'C': 'Its size',\n",
       " 'D': 'Its distance from you',\n",
       " 'target': 'D'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task_data['test'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bec64ac3-96b4-45cd-b79e-6c8ad6fae234",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt_q = mmlu_prompt[task] + \"\\n\\n\" + task_data['test'][0]['input'] + '\\n'\n",
    "for letter in ['A', 'B', 'C', 'D']:\n",
    "    prompt_q += '(' + letter + ') ' + task_data['test'][0][letter] + ' '\n",
    "prompt_q += \"\\nA: Let's think step by step.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8847b4b5-99fe-47ab-941b-5bd02c43755e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about astronomy.\n",
      "\n",
      "Q: Where do most short-period comets come from and how do we know?\n",
      "(A) The Kuiper belt; short period comets tend to be in the plane of the solar system just like the Kuiper belt. (B) The Kuiper belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the Kuiper belt. (C) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt. (D) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.\n",
      "A: Let's think step by step. Most short-period comets come from the Kuiper belt, and we know because short period coments tend to be in the plane of the solar system, just like the Kuiper belt is. The answer is (A).\n",
      "\n",
      "Q: You are pushing a truck along a road. Would it be easier to accelerate this truck on Mars? Why? (Assume there is no friction)\n",
      "(A) It would be harder since the truck is heavier on Mars. (B) It would be easier since the truck is lighter on Mars. (C) It would be harder since the truck is lighter on Mars. (D) It would be the same no matter where you are.\n",
      "A: Let's think step by step. If we assume that there is no friction, the force needed to accelerate the truck is by Newton’s second law only dependent on the mass of the truck. Hence (A), (B) and (C) are incorrect since it doesn’t matter that it’s on Mars, and (D) is the correct answer. The answer is (D).\n",
      "\n",
      "Q: Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?\n",
      "(A) 10000 times more (B) 100 times more (C) 1000 times more (D) 10 times more\n",
      "A: Let's think step by step. The amount of light is proportional to the aperture area $A = \\pi D^2/4$ for a lens with diameter $D$, so the relative amounts of light between the eye with diameter 5mm and the telescope with diameter 50mm is $(50 cm)^2/(5mm)^2 = 10000$. The answer is (A).\n",
      "\n",
      "Q: Why isn't there a planet where the asteroid belt is located?\n",
      "(A) A planet once formed here but it was broken apart by a catastrophic collision. (B) There was not enough material in this part of the solar nebula to form a planet. (C) There was too much rocky material to form a terrestrial planet but not enough gaseous material to form a jovian planet. (D) Resonance with Jupiter prevented material from collecting together to form a planet.\n",
      "A: Let's think step by step. The asteroid belt is a stellar disc consisting of a large number of asteroids between Mars and Jupiter's orbits. The asteroids in this belt are affected by the gravitational pull from both other asteroids and nearby planets. Due to the strong gravitational force of Jupiter there are resonances that give rise to low density regions of asteroids known as the Kirkwood gap. So (B) and (C) are not correct since it’s not a lack of material that prevents a planet from being formed, and (A) is incorrect because the Kirkwood gap would have prevented a planet from forming in the first place, and (D) is the correct option. The answer is (D).\n",
      "\n",
      "Q: Why is Mars red?\n",
      "(A) Because the surface is covered with heavily oxidized (\"rusted\") minerals. (B) Because the atmosphere scatters more light at bluer wavelengths transmitting mostly red light. (C) Because Mars is covered with ancient lava flows which are red in color. (D) Because flowing water on Mars's surface altered the surface minerals several billion years ago.\n",
      "A: Let's think step by step. Option (B) is not correct because if the red color was caused by the scattering off the atmosphere, then the earth with a much thicker atmosphere would also look red. Options (C) and (D) are not specific enough about why the color of the surface would be red, while (A) is correct because it explains that the surface is red due to the rusted materials on the surface and the red color comes from the rust. So the correct option is (A). The answer is (A).\n",
      "\n",
      "If you know both the actual brightness of an object and its apparent brightness from your location then with no other information you can estimate:\n",
      "(A) Its speed relative to you (B) Its composition (C) Its size (D) Its distance from you \n",
      "A: Let's think step by step.\n"
     ]
    }
   ],
   "source": [
    "print(prompt_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c91e1e78-0e6d-4d82-9600-b6c76637fa80",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = openai.ChatCompletion.create(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "        {\"role\": \"user\", \"content\": prompt_q},\n",
    "    ],\n",
    "    temperature=0, \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b1e7b952-30b3-43b3-aa11-06595a7f592c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The actual brightness of an object is its intrinsic brightness, while the apparent brightness is how bright it appears from a certain location. The difference between the two is known as the distance modulus, which is related to the distance to the object. Therefore, if you know both the actual and apparent brightness of an object, you can estimate its distance from you. The answer is (D).\n"
     ]
    }
   ],
   "source": [
    "print(response['choices'][0]['message']['content'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "867945a0-6527-409e-801e-4244eaf75e61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_answer_mmlu(pred_str, ans_str):\n",
    "    pattern = 'the answer is ('\n",
    "    pred = pred_str.lower().split(pattern)\n",
    "    \n",
    "    if(len(pred) > 1):\n",
    "        # print(pred)\n",
    "        pred = pred[1][0]\n",
    "        gold = ans_str.split('A:\\n')[1][0].lower()\n",
    "        # print('debug 1, pred %s, gold %s' % (pred, gold))\n",
    "        return pred == gold\n",
    "    else: \n",
    "        pred = 'C'\n",
    "        gold = ans_str.split('A:\\n')[1][0].lower()\n",
    "        # print('debug 2, pred %s, gold %s' % (pred, gold))\n",
    "        return pred == gold\n",
    "\n",
    "def parse_pred_ans(filename):\n",
    "    with open(filename) as fd: lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        if(l.startswith('Q: ')):\n",
    "            if(am is not None and a is not None):\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                # print(am)\n",
    "                # print(a)\n",
    "                if(test_answer_mmlu(am, a)):\n",
    "                    acc += 1\n",
    "            current_mode = 'q'\n",
    "            q = l\n",
    "            num_q += 1\n",
    "        elif(l.startswith('A_model:')):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif(l.startswith('A:')):\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            if(current_mode == 'q'): q += l\n",
    "            elif(current_mode == 'am'): am += l\n",
    "            elif(current_mode == 'a'): a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "                \n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    # print(am)\n",
    "    # print(a)\n",
    "    if(test_answer_mmlu(am, a)):\n",
    "        acc += 1\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold\n",
    "\n",
    "def test_finished(ans_model):\n",
    "    if('answer is' in ans_model): return True\n",
    "    else: return False\n",
    "\n",
    "def extract_ans(ans_model):\n",
    "    ans_model = ans_model.split('\\n')\n",
    "    ans = []\n",
    "    residual = []\n",
    "    for li, al in enumerate(ans_model):\n",
    "        ans.append(al)\n",
    "        if('answer is' in al):\n",
    "            break\n",
    "    residual = list(ans_model[li + 1:])\n",
    "    ans = '\\n'.join(ans)\n",
    "    residual = '\\n'.join(residual)\n",
    "    return ans, residual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b398b88e-d8fd-47fc-be03-f4627f6719e1",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|                                                                                                                                                                                       | 0/151 [00:00<?, ?it/s]\u001b[A\n",
      "  1%|█▏                                                                                                                                                                             | 1/151 [00:01<04:46,  1.91s/it]\u001b[A\n",
      "  1%|██▎                                                                                                                                                                            | 2/151 [00:05<06:34,  2.65s/it]\u001b[A\n",
      "  2%|███▍                                                                                                                                                                           | 3/151 [00:08<07:09,  2.90s/it]\u001b[A\n",
      "  3%|████▋                                                                                                                                                                          | 4/151 [00:10<06:11,  2.52s/it]\u001b[A\n",
      "  3%|█████▊                                                                                                                                                                         | 5/151 [00:14<07:32,  3.10s/it]\u001b[A\n",
      "  4%|██████▉                                                                                                                                                                        | 6/151 [00:16<07:00,  2.90s/it]\u001b[A\n",
      "  5%|████████                                                                                                                                                                       | 7/151 [00:18<05:54,  2.46s/it]\u001b[A\n",
      "  5%|█████████▎                                                                                                                                                                     | 8/151 [00:20<05:23,  2.26s/it]\u001b[A\n",
      "  6%|██████████▍                                                                                                                                                                    | 9/151 [00:22<05:29,  2.32s/it]\u001b[A\n",
      "  7%|███████████▌                                                                                                                                                                  | 10/151 [00:25<05:50,  2.48s/it]\u001b[A\n",
      "  7%|████████████▋                                                                                                                                                                 | 11/151 [00:27<05:39,  2.43s/it]\u001b[A\n",
      "  8%|█████████████▊                                                                                                                                                                | 12/151 [00:29<05:04,  2.19s/it]\u001b[A\n",
      "  9%|██████████████▉                                                                                                                                                               | 13/151 [00:31<04:48,  2.09s/it]\u001b[A\n",
      "  9%|████████████████▏                                                                                                                                                             | 14/151 [00:32<04:20,  1.90s/it]\u001b[A\n",
      " 10%|█████████████████▎                                                                                                                                                            | 15/151 [00:33<03:42,  1.64s/it]\u001b[A\n",
      " 11%|██████████████████▍                                                                                                                                                           | 16/151 [00:36<04:30,  2.01s/it]\u001b[A\n",
      " 11%|███████████████████▌                                                                                                                                                          | 17/151 [00:39<05:01,  2.25s/it]\u001b[A\n",
      " 12%|████████████████████▋                                                                                                                                                         | 18/151 [00:41<04:34,  2.07s/it]\u001b[A\n",
      " 13%|█████████████████████▉                                                                                                                                                        | 19/151 [00:43<04:28,  2.03s/it]\u001b[A\n",
      " 13%|███████████████████████                                                                                                                                                       | 20/151 [00:44<03:46,  1.73s/it]\u001b[A\n",
      " 14%|████████████████████████▏                                                                                                                                                     | 21/151 [00:46<04:04,  1.88s/it]\u001b[A\n",
      " 15%|█████████████████████████▎                                                                                                                                                    | 22/151 [00:48<04:15,  1.98s/it]\u001b[A\n",
      " 15%|██████████████████████████▌                                                                                                                                                   | 23/151 [00:51<04:51,  2.28s/it]\u001b[A\n",
      " 16%|███████████████████████████▋                                                                                                                                                  | 24/151 [00:53<04:25,  2.09s/it]\u001b[A\n",
      " 17%|████████████████████████████▊                                                                                                                                                 | 25/151 [00:54<03:58,  1.89s/it]\u001b[A\n",
      " 17%|█████████████████████████████▉                                                                                                                                                | 26/151 [00:57<04:34,  2.19s/it]\u001b[A\n",
      " 18%|███████████████████████████████                                                                                                                                               | 27/151 [00:59<04:34,  2.21s/it]\u001b[A\n",
      " 19%|████████████████████████████████▎                                                                                                                                             | 28/151 [01:01<04:21,  2.13s/it]\u001b[A\n",
      " 19%|█████████████████████████████████▍                                                                                                                                            | 29/151 [01:03<04:10,  2.06s/it]\u001b[A\n",
      " 20%|██████████████████████████████████▌                                                                                                                                           | 30/151 [01:06<04:25,  2.19s/it]\u001b[A\n",
      " 21%|███████████████████████████████████▋                                                                                                                                          | 31/151 [01:10<05:31,  2.77s/it]\u001b[A\n",
      " 21%|████████████████████████████████████▊                                                                                                                                         | 32/151 [01:11<04:30,  2.27s/it]\u001b[A\n",
      " 22%|██████████████████████████████████████                                                                                                                                        | 33/151 [01:13<04:24,  2.24s/it]\u001b[A\n",
      " 23%|███████████████████████████████████████▏                                                                                                                                      | 34/151 [01:14<03:47,  1.94s/it]\u001b[A\n",
      " 23%|████████████████████████████████████████▎                                                                                                                                     | 35/151 [01:16<03:33,  1.84s/it]\u001b[A\n",
      " 24%|█████████████████████████████████████████▍                                                                                                                                    | 36/151 [01:17<03:17,  1.71s/it]\u001b[A\n",
      " 25%|██████████████████████████████████████████▋                                                                                                                                   | 37/151 [01:20<03:49,  2.02s/it]\u001b[A\n",
      " 25%|███████████████████████████████████████████▊                                                                                                                                  | 38/151 [01:23<04:12,  2.24s/it]\u001b[A\n",
      " 26%|████████████████████████████████████████████▉                                                                                                                                 | 39/151 [01:25<04:18,  2.31s/it]\u001b[A\n",
      " 26%|██████████████████████████████████████████████                                                                                                                                | 40/151 [01:28<04:26,  2.40s/it]\u001b[A\n",
      " 27%|███████████████████████████████████████████████▏                                                                                                                              | 41/151 [01:29<03:37,  1.98s/it]\u001b[A\n",
      " 28%|████████████████████████████████████████████████▍                                                                                                                             | 42/151 [01:32<04:11,  2.31s/it]\u001b[A\n",
      " 28%|█████████████████████████████████████████████████▌                                                                                                                            | 43/151 [01:34<03:46,  2.10s/it]\u001b[A\n",
      " 29%|██████████████████████████████████████████████████▋                                                                                                                           | 44/151 [01:36<03:47,  2.12s/it]\u001b[A\n",
      " 30%|███████████████████████████████████████████████████▊                                                                                                                          | 45/151 [01:37<03:27,  1.96s/it]\u001b[A\n",
      " 30%|█████████████████████████████████████████████████████                                                                                                                         | 46/151 [01:40<03:47,  2.17s/it]\u001b[A\n",
      " 31%|██████████████████████████████████████████████████████▏                                                                                                                       | 47/151 [01:42<03:48,  2.19s/it]\u001b[A\n",
      " 32%|███████████████████████████████████████████████████████▎                                                                                                                      | 48/151 [01:45<03:55,  2.29s/it]\u001b[A\n",
      " 32%|████████████████████████████████████████████████████████▍                                                                                                                     | 49/151 [01:48<04:28,  2.63s/it]\u001b[A\n",
      " 33%|█████████████████████████████████████████████████████████▌                                                                                                                    | 50/151 [01:51<04:29,  2.67s/it]\u001b[A\n",
      " 34%|██████████████████████████████████████████████████████████▊                                                                                                                   | 51/151 [01:53<04:04,  2.45s/it]\u001b[A\n",
      " 34%|███████████████████████████████████████████████████████████▉                                                                                                                  | 52/151 [01:55<03:49,  2.32s/it]\u001b[A\n",
      " 35%|█████████████████████████████████████████████████████████████                                                                                                                 | 53/151 [01:56<03:06,  1.91s/it]\u001b[A\n",
      " 36%|██████████████████████████████████████████████████████████████▏                                                                                                               | 54/151 [01:59<03:37,  2.24s/it]\u001b[A\n",
      " 36%|███████████████████████████████████████████████████████████████▍                                                                                                              | 55/151 [02:02<03:53,  2.43s/it]\u001b[A\n",
      " 37%|████████████████████████████████████████████████████████████████▌                                                                                                             | 56/151 [02:03<03:18,  2.09s/it]\u001b[A\n",
      " 38%|█████████████████████████████████████████████████████████████████▋                                                                                                            | 57/151 [02:06<03:51,  2.46s/it]\u001b[A\n",
      " 38%|██████████████████████████████████████████████████████████████████▊                                                                                                           | 58/151 [02:08<03:29,  2.25s/it]\u001b[A\n",
      " 39%|███████████████████████████████████████████████████████████████████▉                                                                                                          | 59/151 [02:10<03:30,  2.29s/it]\u001b[A\n",
      " 40%|█████████████████████████████████████████████████████████████████████▏                                                                                                        | 60/151 [02:13<03:25,  2.26s/it]\u001b[A\n",
      " 40%|██████████████████████████████████████████████████████████████████████▎                                                                                                       | 61/151 [02:15<03:28,  2.31s/it]\u001b[A\n",
      " 41%|███████████████████████████████████████████████████████████████████████▍                                                                                                      | 62/151 [02:16<02:55,  1.98s/it]\u001b[A\n",
      " 42%|████████████████████████████████████████████████████████████████████████▌                                                                                                     | 63/151 [02:18<02:40,  1.83s/it]\u001b[A\n",
      " 42%|█████████████████████████████████████████████████████████████████████████▋                                                                                                    | 64/151 [02:20<03:02,  2.10s/it]\u001b[A\n",
      " 43%|██████████████████████████████████████████████████████████████████████████▉                                                                                                   | 65/151 [02:22<02:48,  1.96s/it]\u001b[A\n",
      " 44%|████████████████████████████████████████████████████████████████████████████                                                                                                  | 66/151 [02:25<02:58,  2.10s/it]\u001b[A\n",
      " 44%|█████████████████████████████████████████████████████████████████████████████▏                                                                                                | 67/151 [02:26<02:41,  1.93s/it]\u001b[A\n",
      " 45%|██████████████████████████████████████████████████████████████████████████████▎                                                                                               | 68/151 [02:28<02:43,  1.98s/it]\u001b[A\n",
      " 46%|███████████████████████████████████████████████████████████████████████████████▌                                                                                              | 69/151 [02:30<02:49,  2.07s/it]\u001b[A\n",
      " 46%|████████████████████████████████████████████████████████████████████████████████▋                                                                                             | 70/151 [02:32<02:35,  1.93s/it]\u001b[A\n",
      " 47%|█████████████████████████████████████████████████████████████████████████████████▊                                                                                            | 71/151 [02:34<02:39,  1.99s/it]\u001b[A\n",
      " 48%|██████████████████████████████████████████████████████████████████████████████████▉                                                                                           | 72/151 [02:36<02:36,  1.98s/it]\u001b[A\n",
      " 48%|████████████████████████████████████████████████████████████████████████████████████                                                                                          | 73/151 [02:39<02:53,  2.22s/it]\u001b[A\n",
      " 49%|█████████████████████████████████████████████████████████████████████████████████████▎                                                                                        | 74/151 [02:40<02:28,  1.93s/it]\u001b[A\n",
      " 50%|██████████████████████████████████████████████████████████████████████████████████████▍                                                                                       | 75/151 [02:42<02:28,  1.95s/it]\u001b[A\n",
      " 50%|███████████████████████████████████████████████████████████████████████████████████████▌                                                                                      | 76/151 [02:44<02:22,  1.91s/it]\u001b[A\n",
      " 51%|████████████████████████████████████████████████████████████████████████████████████████▋                                                                                     | 77/151 [02:46<02:21,  1.91s/it]\u001b[A\n",
      " 52%|█████████████████████████████████████████████████████████████████████████████████████████▉                                                                                    | 78/151 [02:48<02:28,  2.03s/it]\u001b[A\n",
      " 52%|███████████████████████████████████████████████████████████████████████████████████████████                                                                                   | 79/151 [02:49<02:10,  1.81s/it]\u001b[A\n",
      " 53%|████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                 | 80/151 [02:51<01:58,  1.67s/it]\u001b[A\n",
      " 54%|█████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                | 81/151 [02:54<02:31,  2.16s/it]\u001b[A\n",
      " 54%|██████████████████████████████████████████████████████████████████████████████████████████████▍                                                                               | 82/151 [02:56<02:18,  2.00s/it]\u001b[A\n",
      " 55%|███████████████████████████████████████████████████████████████████████████████████████████████▋                                                                              | 83/151 [02:58<02:15,  1.99s/it]\u001b[A\n",
      " 56%|████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                             | 84/151 [03:00<02:20,  2.09s/it]\u001b[A\n",
      " 56%|█████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                            | 85/151 [03:02<02:11,  1.99s/it]\u001b[A\n",
      " 57%|███████████████████████████████████████████████████████████████████████████████████████████████████                                                                           | 86/151 [03:03<01:54,  1.76s/it]\u001b[A\n",
      " 58%|████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                         | 87/151 [03:06<02:24,  2.26s/it]\u001b[A\n",
      " 58%|█████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                        | 88/151 [03:10<02:38,  2.52s/it]\u001b[A\n",
      " 59%|██████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                       | 89/151 [03:11<02:07,  2.06s/it]\u001b[A\n",
      " 60%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                      | 90/151 [03:14<02:30,  2.47s/it]\u001b[A\n",
      " 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                     | 91/151 [03:17<02:45,  2.76s/it]\u001b[A\n",
      " 61%|██████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                    | 92/151 [03:20<02:37,  2.68s/it]\u001b[A\n",
      " 62%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                  | 93/151 [03:22<02:19,  2.40s/it]\u001b[A\n",
      " 62%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                 | 94/151 [03:24<02:15,  2.38s/it]\u001b[A\n",
      " 63%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                | 95/151 [03:26<02:08,  2.29s/it]\u001b[A\n",
      " 64%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                               | 96/151 [03:27<01:40,  1.84s/it]\u001b[A\n",
      " 64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                              | 97/151 [03:28<01:31,  1.70s/it]\u001b[A\n",
      " 65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                             | 98/151 [03:30<01:37,  1.84s/it]\u001b[A\n",
      " 66%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                            | 99/151 [03:31<01:19,  1.53s/it]\u001b[A\n",
      " 66%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                          | 100/151 [03:34<01:33,  1.83s/it]\u001b[A\n",
      " 67%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                         | 101/151 [03:36<01:39,  1.99s/it]\u001b[A\n",
      " 68%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                        | 102/151 [03:37<01:27,  1.79s/it]\u001b[A\n",
      " 68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                       | 103/151 [03:39<01:23,  1.75s/it]\u001b[A\n",
      " 69%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                     | 104/151 [03:42<01:39,  2.11s/it]\u001b[A\n",
      " 70%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                    | 105/151 [03:43<01:24,  1.84s/it]\u001b[A\n",
      " 70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                   | 106/151 [03:45<01:22,  1.83s/it]\u001b[A\n",
      " 71%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                  | 107/151 [03:47<01:25,  1.94s/it]\u001b[A\n",
      " 72%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                 | 108/151 [03:50<01:35,  2.21s/it]\u001b[A\n",
      " 72%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                | 109/151 [03:53<01:43,  2.45s/it]\u001b[A\n",
      " 73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                               | 110/151 [03:55<01:38,  2.39s/it]\u001b[A\n",
      " 74%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                             | 111/151 [03:57<01:23,  2.08s/it]\u001b[A\n",
      " 74%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                            | 112/151 [03:59<01:27,  2.25s/it]\u001b[A\n",
      " 75%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                           | 113/151 [04:01<01:20,  2.13s/it]\u001b[A\n",
      " 75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                          | 114/151 [04:04<01:25,  2.32s/it]\u001b[A\n",
      " 76%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                         | 115/151 [04:05<01:09,  1.93s/it]\u001b[A\n",
      " 77%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                        | 116/151 [04:07<01:05,  1.88s/it]\u001b[A\n",
      " 77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                       | 117/151 [04:09<01:06,  1.95s/it]\u001b[A\n",
      " 78%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                     | 118/151 [04:11<01:09,  2.11s/it]\u001b[A\n",
      " 79%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                    | 119/151 [04:13<01:03,  1.99s/it]\u001b[A\n",
      " 79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                   | 120/151 [04:16<01:13,  2.37s/it]\u001b[A\n",
      " 80%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                  | 121/151 [04:17<00:57,  1.92s/it]\u001b[A\n",
      " 81%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                 | 122/151 [04:20<01:06,  2.28s/it]\u001b[A\n",
      " 81%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                | 123/151 [04:22<01:01,  2.19s/it]\u001b[A\n",
      " 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                               | 124/151 [04:25<01:03,  2.34s/it]\u001b[A\n",
      " 83%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                             | 125/151 [04:26<00:49,  1.92s/it]\u001b[A\n",
      " 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                            | 126/151 [04:28<00:51,  2.08s/it]\u001b[A\n",
      " 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                           | 127/151 [04:29<00:41,  1.72s/it]\u001b[A\n",
      " 85%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                          | 128/151 [04:32<00:46,  2.04s/it]\u001b[A\n",
      " 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                         | 129/151 [04:35<00:48,  2.19s/it]\u001b[A\n",
      " 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                        | 130/151 [04:38<00:51,  2.46s/it]\u001b[A\n",
      " 87%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                       | 131/151 [04:41<00:53,  2.65s/it]\u001b[A\n",
      " 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                     | 132/151 [04:43<00:46,  2.43s/it]\u001b[A\n",
      " 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                    | 133/151 [04:45<00:43,  2.39s/it]\u001b[A\n",
      " 89%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                   | 134/151 [04:48<00:43,  2.58s/it]\u001b[A\n",
      " 89%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                  | 135/151 [04:49<00:34,  2.17s/it]\u001b[A\n",
      " 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                 | 136/151 [04:51<00:30,  2.01s/it]\u001b[A\n",
      " 91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                | 137/151 [04:53<00:26,  1.91s/it]\u001b[A\n",
      " 91%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████               | 138/151 [04:54<00:21,  1.66s/it]\u001b[A\n",
      " 92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎             | 139/151 [04:55<00:18,  1.52s/it]\u001b[A\n",
      " 93%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍            | 140/151 [04:57<00:18,  1.66s/it]\u001b[A\n",
      " 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌           | 141/151 [04:59<00:19,  1.97s/it]\u001b[A\n",
      " 94%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋          | 142/151 [05:02<00:17,  1.99s/it]\u001b[A\n",
      " 95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊         | 143/151 [05:05<00:19,  2.45s/it]\u001b[A\n",
      " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉        | 144/151 [05:08<00:19,  2.75s/it]\u001b[A\n",
      " 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏      | 145/151 [05:10<00:13,  2.24s/it]\u001b[A\n",
      " 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎     | 146/151 [05:11<00:10,  2.11s/it]\u001b[A\n",
      " 97%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍    | 147/151 [05:12<00:06,  1.72s/it]\u001b[A\n",
      " 98%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌   | 148/151 [05:15<00:05,  1.91s/it]\u001b[A\n",
      " 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋  | 149/151 [05:17<00:03,  1.98s/it]\u001b[A\n",
      " 99%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 150/151 [05:19<00:02,  2.00s/it]\u001b[A\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 151/151 [05:19<00:00,  2.12s/it]\u001b[A\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_gpt_3.5_turbo_%s.txt' % task, 'w') as fd:\n",
    "    for q_ in tqdm(task_data['test'], total=len(task_data['test'])):\n",
    "        q = q_['input'] + '\\n'\n",
    "        for letter in ['A', 'B', 'C', 'D']:\n",
    "            q += '(' + letter + ') ' + q_[letter] + ' '\n",
    "        q += \"\\nA: Let's think step by step.\"  \n",
    "            \n",
    "        prompt_q = mmlu_prompt[task] + \"\\n\\n\" + q\n",
    "\n",
    "        response = completion_with_backoff(\n",
    "              model=\"gpt-3.5-turbo\",\n",
    "              messages=[\n",
    "                    {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "                    {\"role\": \"user\", \"content\": prompt_q},\n",
    "                ],\n",
    "            temperature=0\n",
    "            )\n",
    "        ans_model = response['choices'][0]['message']['content']\n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        a = q_['target']\n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1\n",
    "        # if(i == 2): break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8ea5012b-444a-49d8-8752-51ea485d9beb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "astronomy\n",
      "num_q 151 correct 98 ratio 0.6490\n"
     ]
    }
   ],
   "source": [
    "print(task)\n",
    "_, _, _ = parse_pred_ans('outputs/test_gpt_3.5_turbo_%s.txt' % task)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
